{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rewrite-Retrieve-Read\n",
    "\n",
    "You can also run this notebook online [on Noteable.io](https://app.noteable.io/published/d2cb020b-8341-4cff-b898-be33d3c62e21).\n",
    "\n",
    "**Rewrite-Retrieve-Read** is a method proposed in the paper [Query Rewriting for Retrieval-Augmented Large Language Models](https://arxiv.org/pdf/2305.14283.pdf)\n",
    "\n",
    "> Because the original query can not be always optimal to retrieve for the LLM, especially in the real world... we first prompt an LLM to rewrite the queries, then conduct retrieval-augmented reading\n",
    "\n",
    "We show how you can easily do that with LangChain Expression Language in this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline\n",
    "\n",
    "Baseline RAG (**Retrieve-and-read**) can be done like the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Deno.env.set(\"OPENAI_API_KEY\", \"\");\n",
    "// Deno.env.set(\"TAVILY_API_KEY\", \"\");\n",
    "\n",
    "import { PromptTemplate } from \"npm:langchain@0.0.172/prompts\";\n",
    "import { ChatOpenAI } from \"npm:langchain@0.0.172/chat_models/openai\";\n",
    "import { StringOutputParser } from \"npm:langchain@0.0.172/schema/output_parser\";\n",
    "import { RunnableSequence, RunnablePassthrough } from \"npm:langchain@0.0.172/schema/runnable\";\n",
    "import { TavilySearchAPIRetriever } from \"npm:langchain@0.0.172/retrievers/tavily_search_api\";\n",
    "import type { Document } from \"npm:langchain@0.0.172/schema/document\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "const template = `Answer the users question based only on the following context:\n",
    "\n",
    "<context>\n",
    "  {context}\n",
    "</context>\n",
    "\n",
    "Question: {question}`\n",
    "\n",
    "const prompt = PromptTemplate.fromTemplate(template);\n",
    "\n",
    "const model = new ChatOpenAI({\n",
    "  temperature: 0,\n",
    "  openAIApiKey: Deno.env.get(\"OPENAI_API_KEY\"),\n",
    "})\n",
    "\n",
    "const retriever = new TavilySearchAPIRetriever({\n",
    "  k: 3,\n",
    "  apiKey: Deno.env.get(\"TAVILY_API_KEY\"),\n",
    "});\n",
    "\n",
    "const formatDocs = (documents: Document[]) => documents.map((doc) => doc.pageContent).join(\"\\n\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "const chain = RunnableSequence.from([\n",
    "  {\n",
    "    context: retriever.pipe(formatDocs),\n",
    "    question: new RunnablePassthrough()\n",
    "  },\n",
    "  prompt,\n",
    "  model,\n",
    "  new StringOutputParser()\n",
    "]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[32m\"LangChain is an open source framework designed to simplify the creation of applications using large \"\u001b[39m... 217 more characters"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "const simpleQuery = \"what is langchain?\";\n",
    "\n",
    "await chain.invoke(simpleQuery);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While this is fine for well formatted queries, it can break down for more complicated queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[32m'Based on the given context, there is no information about \"langchain\" or any relevance to Sam Bankma'\u001b[39m... 16 more characters"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "const distractedQuery = \"man that sam bankman fried trial was crazy! what is langchain?\";\n",
    "\n",
    "await chain.invoke(distractedQuery);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is because the retriever does a bad job with these \"distracted\" queries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\n",
       "  Document {\n",
       "    pageContent: \u001b[32m\"Three of Sam Bankman-Fried's closest former associates have turned against him, and their testimony \"\u001b[39m... 98 more characters,\n",
       "    metadata: {\n",
       "      title: \u001b[32m\"3 key witnesses who could send Sam Bankman-Fried to prison for life : NPR\"\u001b[39m,\n",
       "      source: \u001b[32m\"https://www.npr.org/2023/10/21/1207143248/sam-bankman-fried-trial-ftx-crypto-fraud-alameda\"\u001b[39m,\n",
       "      score: \u001b[33m0.95083\u001b[39m,\n",
       "      images: \u001b[1mnull\u001b[22m\n",
       "    }\n",
       "  },\n",
       "  Document {\n",
       "    pageContent: \u001b[32m\"October 18, 20232:46 PM PDTUpdated 2 days ago Former FTX Chief Executive Sam Bankman-Fried, who face\"\u001b[39m... 94 more characters,\n",
       "    metadata: {\n",
       "      title: \u001b[32m\"Sam Bankman-Fried trial jury sees his profane messages about regulators ...\"\u001b[39m,\n",
       "      source: \u001b[32m\"https://www.reuters.com/legal/sam-bankman-fried-trial-jury-sees-his-profane-message-about-regulators\"\u001b[39m... 12 more characters,\n",
       "      score: \u001b[33m0.91945\u001b[39m,\n",
       "      images: \u001b[1mnull\u001b[22m\n",
       "    }\n",
       "  },\n",
       "  Document {\n",
       "    pageContent: \u001b[32m\"FTX founder Sam Bankman-Fried leaves the courthouse following his arraignment in New York City on De\"\u001b[39m... 84 more characters,\n",
       "    metadata: {\n",
       "      title: \u001b[32m\"5 things to know about Sam Bankman-Fried's trial so far : NPR\"\u001b[39m,\n",
       "      source: \u001b[32m\"https://www.npr.org/2023/10/14/1205737325/criminal-mastermind-hapless-dude-sam-bankman-fried-trial-f\"\u001b[39m... 2 more characters,\n",
       "      score: \u001b[33m0.90957\u001b[39m,\n",
       "      images: \u001b[1mnull\u001b[22m\n",
       "    }\n",
       "  }\n",
       "]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await retriever.invoke(distractedQuery);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rewrite-Retrieve-Read Implementation\n",
    "\n",
    "The main part is a rewriter to rewrite the search query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "const rewriteTemplate = `Provide a better search query for \\\n",
    "web search engine to answer the given question, end \\\n",
    "the queries with ’**’. Question: \\\n",
    "{x} Answer:`;\n",
    "const rewritePrompt = PromptTemplate.fromTemplate(rewriteTemplate);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Parser to remove the `**`\n",
    "const _parse = (text) => text.replace(\"**\", \"\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "// rewriter = rewrite_prompt | ChatOpenAI(temperature=0) | StrOutputParser() | _parse\n",
    "const rewriter = RunnableSequence.from([\n",
    "  rewritePrompt,\n",
    "  new ChatOpenAI({ temperature: 0 }),\n",
    "  new StringOutputParser(),\n",
    "  _parse\n",
    "]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[32m\"What is the definition and purpose of Langchain?\"\u001b[39m"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await rewriter.invoke({\"x\": distractedQuery})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "const rewriteRetrieveReadChain = RunnableSequence.from([\n",
    "  {\n",
    "    context: RunnableSequence.from([\n",
    "      { x: new RunnablePassthrough() },\n",
    "      rewriter,\n",
    "      retriever,\n",
    "      formatDocs,\n",
    "    ]),\n",
    "    question: new RunnablePassthrough()\n",
    "  },\n",
    "  prompt,\n",
    "  model,\n",
    "  new StringOutputParser()\n",
    "]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[32m\"LangChain is a framework designed to simplify the creation of applications using large language mode\"\u001b[39m... 293 more characters"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await rewriteRetrieveReadChain.invoke(distractedQuery);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Deno",
   "language": "typescript",
   "name": "deno"
  },
  "language_info": {
   "file_extension": ".ts",
   "mimetype": "text/x.typescript",
   "name": "typescript",
   "nb_converter": "script",
   "pygments_lexer": "typescript",
   "version": "5.2.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
